import numpy as np
from tqdm import tqdm
import concurrent.futures
from functools import partial
import pickle
import os
import matplotlib
matplotlib.use('Agg')  # Use non-interactive backend to prevent plot windows from appearing
from faster_caching import *
from plot_utils import *

def run_single_policy(policy_func, a_list, C, xi=None, Q=None, predictor=None, forced=None, threshold=None):
    """Call ``policy_func`` with the positional/keyword layout its signature expects.

    The call layout is chosen by ``policy_func.__name__``:

    * ``tail_optimized_LRU_cache_policy`` -> ``(a_list, C, xi, Q, predictor=..., forced=...)``
    * ``LRU_cache_policy``                -> ``(a_list, C, forced)``
    * ``thre_lru_cache_policy``           -> ``(a_list, C, threshold, forced)``

    Returns:
        Whatever the selected policy function returns.

    Raises:
        ValueError: If ``policy_func`` has any other ``__name__``.
    """
    func_name = policy_func.__name__

    # Each entry defers the call so only the matching policy is executed.
    invokers = {
        'tail_optimized_LRU_cache_policy':
            lambda: policy_func(a_list, C, xi, Q, predictor=predictor, forced=forced),
        'LRU_cache_policy':
            lambda: policy_func(a_list, C, forced),
        'thre_lru_cache_policy':
            lambda: policy_func(a_list, C, threshold, forced),
    }

    invoke = invokers.get(func_name)
    if invoke is None:
        raise ValueError(f"Unknown policy function: {func_name}")
    return invoke()

def count_turns_above_latency_threshold(uncached_tokens_list, latency_threshold):
    """
    Return how many turns have a latency strictly greater than ``latency_threshold``.

    Every entry of ``uncached_tokens_list`` is first mapped to a latency with
    ``uncached_tokens_to_latency``; the comparison results are then summed.

    Args:
        uncached_tokens_list: Iterable of per-turn uncached-token counts.
        latency_threshold: Latency threshold in seconds (strict ``>`` comparison).

    Returns:
        The number of turns whose latency exceeds the threshold (0 for empty input).
    """
    # Convert every turn up front, then count the ones over the limit.
    turn_latencies = list(map(uncached_tokens_to_latency, uncached_tokens_list))
    is_slow = (value > latency_threshold for value in turn_latencies)
    return sum(is_slow)

def process_cache_capacity(C, a_list, xi, Q, forced, latency_threshold):
    """Run every configured cache policy at capacity ``C`` and count slow turns.

    The three policies are submitted to a ``ThreadPoolExecutor`` and their
    outputs are collected as they finish. Each output is then passed to
    ``count_turns_above_latency_threshold`` with ``latency_threshold``.
    NOTE(review): if the policy functions are pure-Python CPU-bound code, the
    GIL will largely serialize these threads.

    Args:
        C: Cache capacity passed to every policy.
        a_list: Request trace passed to every policy.
        xi: Forwarded to ``run_single_policy``.
        Q: Forwarded to ``run_single_policy``.
        forced: Forwarded to ``run_single_policy``.
        latency_threshold: Latency threshold in seconds used for counting.

    Returns:
        dict: ``'vanilla_lru'``, ``'lru'`` and ``'thre_lru'`` mapped to their
        slow-turn counts (``np.nan`` for a policy whose run raised), plus
        ``'capacity'`` mapped to ``C``.
    """
    # (result name, policy function, predictor, threshold)
    policy_specs = (
        ('vanilla_lru', LRU_cache_policy, None, None),
        # NOTE(review): predictor is the *string* 'None' here, unlike the
        # Python None used by the other entries -- confirm this is what
        # tail_optimized_LRU_cache_policy expects.
        ('lru', tail_optimized_LRU_cache_policy, 'None', None),
        ('thre_lru', thre_lru_cache_policy, None, 1024),
    )

    raw_outputs = {}
    with concurrent.futures.ThreadPoolExecutor() as pool:
        pending = {
            pool.submit(
                run_single_policy,
                policy_func=func,
                a_list=a_list,
                C=C,
                xi=xi,
                Q=Q,
                predictor=predictor,
                forced=forced,
                threshold=threshold,
            ): name
            for name, func, predictor, threshold in policy_specs
        }

        for done in concurrent.futures.as_completed(pending):
            name = pending[done]
            try:
                raw_outputs[name] = done.result()
            except Exception as exc:
                # Best-effort: one failing policy must not abort the others.
                print(f"Policy {name} generated an exception: {exc}")
                raw_outputs[name] = None

    # A failed policy is reported as NaN rather than a count.
    counts = {
        name: np.nan if output is None
        else count_turns_above_latency_threshold(output, latency_threshold)
        for name, output in raw_outputs.items()
    }

    # Carry the capacity so callers can re-associate out-of-order results.
    counts['capacity'] = C
    return counts

def run_parallel_cache_evaluation(C_values, a_list, xi, Q, forced, latency_threshold, max_workers=None):
    """Evaluate all cache policies for every capacity in ``C_values`` in parallel.

    Two levels of parallelism are used. Each capacity is dispatched to a worker
    process through ``ProcessPoolExecutor``. Inside that worker,
    ``process_cache_capacity`` runs the individual policies on a thread pool.

    Args:
        C_values: Cache capacities to evaluate.
        a_list: Request trace forwarded unchanged to ``process_cache_capacity``.
            It is pickled once per submitted capacity.
        xi: Forwarded to ``process_cache_capacity``.
        Q: Forwarded to ``process_cache_capacity``.
        forced: Forwarded to ``process_cache_capacity``.
        latency_threshold: Latency threshold in seconds used for counting turns.
        max_workers: Size of the process pool; ``None`` lets the executor decide.

    Returns:
        dict: ``{'lru': [...], 'vanilla_lru': [...], 'thre_lru': [...]}``.
        The i-th element of every list belongs to ``C_values[i]``. If a
        capacity's processing raised, its entries are ``np.nan``, so each list
        always has ``len(C_values)`` elements aligned with ``C_values``.
    """
    process_fn = partial(
        process_cache_capacity,
        a_list=a_list,
        xi=xi,
        Q=Q,
        forced=forced,
        latency_threshold=latency_threshold,
    )

    policy_names = ['lru', 'vanilla_lru', 'thre_lru']

    # Futures complete in arbitrary order and some may fail, so results are
    # stored by capacity and the output lists are assembled in C_values order
    # at the end. This guarantees positional alignment with C_values.
    per_capacity = {}

    # The context managers ensure the bar is closed and the pool is shut down
    # even if something unexpected raises.
    with tqdm(total=len(C_values) * len(policy_names),
              desc="Testing cache capacities") as progress_bar:
        with concurrent.futures.ProcessPoolExecutor(max_workers=max_workers) as executor:
            future_to_capacity = {executor.submit(process_fn, C): C for C in C_values}

            for future in concurrent.futures.as_completed(future_to_capacity):
                # Take the capacity from the submission map so it is known even
                # when the worker raised.
                C = future_to_capacity[future]
                try:
                    per_capacity[C] = future.result()
                except Exception as exc:
                    # Best-effort: keep evaluating the remaining capacities.
                    print(f"Capacity {C} processing generated an exception: {exc}")
                # Advance one step per policy whether or not this capacity
                # succeeded, so the bar always reaches its total.
                progress_bar.update(len(policy_names))

    # Failed capacities (absent from per_capacity) become NaN placeholders.
    return {
        policy: [per_capacity.get(C, {}).get(policy, np.nan) for C in C_values]
        for policy in policy_names
    }

def create_comparison_table(C_values, results, title, filename):
    """Build the per-capacity policy comparison table, save it as CSV, and return it.

    Args:
        C_values: Cache capacities; they form the first column.
        results: Mapping with ``'lru'``, ``'vanilla_lru'`` and ``'thre_lru'``
            sequences, each aligned with ``C_values``.
        title: Accepted by the signature but not used by this function.
        filename: Destination CSV path; the DataFrame index is not written.

    Returns:
        pandas.DataFrame: The table that was written to ``filename``.
    """
    # Output column label -> key in ``results``, in column order.
    policy_columns = (
        ('T-LRU Turns > 200ms', 'lru'),
        ('Vanilla LRU Turns > 200ms', 'vanilla_lru'),
        ('Threshold LRU Turns > 200ms', 'thre_lru'),
    )

    table_data = {'Cache Capacity': C_values}
    table_data.update((label, results[key]) for label, key in policy_columns)
    comparison = pd.DataFrame(table_data)

    comparison.to_csv(filename, index=False)
    print(f"Comparison table saved to {filename}")
    return comparison

def save_results_to_pickle(C_values, results, save_path):
    """Pickle the capacities together with the per-policy results.

    The file at ``save_path`` receives a single dict
    ``{'C_values': C_values, 'policy_results': results}`` and a confirmation
    line is printed.
    """
    payload = dict(C_values=C_values, policy_results=results)

    with open(save_path, "wb") as handle:
        pickle.dump(payload, handle)

    print(f"Results saved to {save_path}")

def generate_latency_comparison_tables():
    """No-op placeholder; performs no work and returns ``None``.

    NOTE(review): an earlier docstring said the real implementation lives in
    another file. Within this script, the latency comparison tables are built
    in the ``__main__`` section via ``create_comparison_table``. Consider
    removing this stub once no caller references it.
    """
    pass

# Main execution
if __name__ == "__main__":
    # Import pandas explicitly; otherwise `pd` is only available if one of the
    # star imports at the top of the file happens to re-export it.
    import pandas as pd

    # Parameter settings
    C_values = [1000, 2000, 4000, 6000, 8000, 10000]  # Cache capacity values
    xi_values = [694, 1498, 2302, 3107, 7934]  # 0.05, 0.1, 0.15, 0.2, 0.5 latency
    Q = 100  # Fixed Q value
    forced = 0  # Fixed forced parameter
    # The "200ms" labels in output files and table columns assume this value.
    latency_threshold = 0.2
    max_conv_idx = 200  # Keep only conversations with conv_idx in [1, max_conv_idx]

    # Load the trace as a DataFrame and restrict it to the selected conversations.
    a_list = pd.DataFrame(load_data("ShareGPT_easy"))
    a_list = a_list[a_list['conv_idx'].isin(range(1, max_conv_idx + 1))]

    # Run experiments for each xi value
    for xi in xi_values:
        print(f"Running for xi={xi}, Q={Q}")

        # One name shared by set_name_run and the output directory so the two
        # cannot drift apart.
        run_name = f"xi{xi}Q{Q}forced{forced}_maxconvidx{max_conv_idx}"
        set_name_run(run_name)

        results = run_parallel_cache_evaluation(
            C_values,
            a_list,
            xi,
            Q,
            forced,
            latency_threshold,
            max_workers=None,  # None lets ProcessPoolExecutor choose the worker count
        )

        save_dir = f"./results/latency_200ms_comparison_{run_name}"
        os.makedirs(save_dir, exist_ok=True)

        create_comparison_table(
            C_values,
            results,
            f"Comparison of Turns with Latency > 200ms (xi={xi}, Q={Q})",
            f"{save_dir}/latency_200ms_comparison.csv",
        )

        # Also keep the raw numbers for later analysis.
        save_results_to_pickle(C_values, results, f"{save_dir}/latency_200ms_comparison.pkl")

        print(f"Analysis completed for xi={xi}")